{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "207a7f06-e0dd-4d0b-8761-99ad25c790c2",
   "metadata": {},
   "source": [
    "$$\n",
    "\\begin{split}\n",
    "x_1 &= x_0 - \\text{mean}(x_0)\\\\\n",
    "x_2 &= \\frac{x_1}{\\sqrt{\\text{mean}(x_1^2) + \\epsilon}}\\\\\n",
    "x_3 &= x_2 \\cdot w\\\\\n",
    "x_4 &= x_3 + b\\\\\n",
    "\\end{split}\n",
    "$$\n",
    "\n",
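    "- Unlike BatchNorm, LayerNorm cannot be turned off (folded into neighbouring weights) at inference time: its mean and variance are recomputed from each input rather than taken from stored running statistics, so removing it would significantly alter the mathematical function implemented by the transformer.\n",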
    "- 定义在 feature 维度，样本甚至是 token 级别，而非 batchnorm 的跨样本；\n",
    "- element-wise 操作；"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "76a755e2-9348-4135-9e1f-5ab8851c1a7c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-19T14:20:31.028856Z",
     "iopub.status.busy": "2024-12-19T14:20:31.028199Z",
     "iopub.status.idle": "2024-12-19T14:20:31.042105Z",
     "shell.execute_reply": "2024-12-19T14:20:31.039936Z",
     "shell.execute_reply.started": "2024-12-19T14:20:31.028809Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x774eeaaa5690>"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "# Seed PyTorch's global RNG so the random toy embedding is reproducible.\n",
    "torch.manual_seed(seed=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "e451a3e7-f801-41e4-aa13-f8f8cc48b85b",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-19T14:20:32.493381Z",
     "iopub.status.busy": "2024-12-19T14:20:32.492778Z",
     "iopub.status.idle": "2024-12-19T14:20:32.501645Z",
     "shell.execute_reply": "2024-12-19T14:20:32.499494Z",
     "shell.execute_reply.started": "2024-12-19T14:20:32.493335Z"
    }
   },
   "outputs": [],
   "source": [
    "# Toy tensor dimensions used throughout the notebook.\n",
    "batch = 2\n",
    "sentence_length = 3\n",
    "embedding_dim = 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "25a3b546-7b63-445a-aa16-49ff6615009e",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-19T14:20:33.562922Z",
     "iopub.status.busy": "2024-12-19T14:20:33.562326Z",
     "iopub.status.idle": "2024-12-19T14:20:33.578811Z",
     "shell.execute_reply": "2024-12-19T14:20:33.576355Z",
     "shell.execute_reply.started": "2024-12-19T14:20:33.562877Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[2., 2., 1., 4.],\n",
       "         [1., 0., 0., 4.],\n",
       "         [0., 3., 3., 4.]],\n",
       "\n",
       "        [[0., 4., 1., 2.],\n",
       "         [0., 0., 2., 1.],\n",
       "         [4., 1., 3., 1.]]])"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Random integers in [0, 5) with shape (batch, sentence_length, embedding_dim), stored as float32.\n",
    "embedding = torch.randint(\n",
    "    low=0, high=5, size=(batch, sentence_length, embedding_dim)\n",
    ").to(torch.float32)\n",
    "embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "e2e4e0fc-5050-4a2c-b156-3a5824e35632",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-19T14:20:35.865423Z",
     "iopub.status.busy": "2024-12-19T14:20:35.864827Z",
     "iopub.status.idle": "2024-12-19T14:20:35.875125Z",
     "shell.execute_reply": "2024-12-19T14:20:35.872957Z",
     "shell.execute_reply.started": "2024-12-19T14:20:35.865378Z"
    }
   },
   "outputs": [],
   "source": [
    "# LayerNorm over the last dimension only, whose size is embedding_dim.\n",
    "ln = nn.LayerNorm(normalized_shape=embedding_dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "a463ac73-23f4-4f5b-8746-e6b5642c9a97",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-19T14:20:37.263538Z",
     "iopub.status.busy": "2024-12-19T14:20:37.262944Z",
     "iopub.status.idle": "2024-12-19T14:20:37.277615Z",
     "shell.execute_reply": "2024-12-19T14:20:37.275507Z",
     "shell.execute_reply.started": "2024-12-19T14:20:37.263494Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Parameter containing:\n",
       " tensor([1., 1., 1., 1.], requires_grad=True),\n",
       " Parameter containing:\n",
       " tensor([0., 0., 0., 0.], requires_grad=True)]"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show every learnable parameter registered on the LayerNorm module.\n",
    "[*ln.parameters()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "a5b43758-d1da-4dd8-a17d-62040bb894e1",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-19T14:23:02.140722Z",
     "iopub.status.busy": "2024-12-19T14:23:02.140125Z",
     "iopub.status.idle": "2024-12-19T14:23:02.154222Z",
     "shell.execute_reply": "2024-12-19T14:23:02.152121Z",
     "shell.execute_reply.started": "2024-12-19T14:23:02.140678Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Parameter containing:\n",
       " tensor([1., 1., 1., 1.], requires_grad=True),\n",
       " Parameter containing:\n",
       " tensor([0., 0., 0., 0.], requires_grad=True))"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The same two parameters accessed by name: scale (weight) and shift (bias).\n",
    "(ln.weight, ln.bias)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "74a4438c-4d0c-462b-a285-6dd0daba2289",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-19T14:25:10.690907Z",
     "iopub.status.busy": "2024-12-19T14:25:10.690240Z",
     "iopub.status.idle": "2024-12-19T14:25:10.702279Z",
     "shell.execute_reply": "2024-12-19T14:25:10.700908Z",
     "shell.execute_reply.started": "2024-12-19T14:25:10.690860Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[-0.2294, -0.2294, -1.1471,  1.6059],\n",
       "         [-0.1525, -0.7625, -0.7625,  1.6775],\n",
       "         [-1.6667,  0.3333,  0.3333,  1.0000]],\n",
       "\n",
       "        [[-1.1832,  1.5213, -0.5071,  0.1690],\n",
       "         [-0.9045, -0.9045,  1.5075,  0.3015],\n",
       "         [ 1.3471, -0.9622,  0.5773, -0.9622]]],\n",
       "       grad_fn=<NativeLayerNormBackward0>)"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reference output: the built-in nn.LayerNorm applied to the toy embedding.\n",
    "# The cells below reproduce this result by hand.\n",
    "ln(embedding)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "badc0d85-3562-42c0-936b-08e7d9d1072b",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-19T14:26:50.031528Z",
     "iopub.status.busy": "2024-12-19T14:26:50.030890Z",
     "iopub.status.idle": "2024-12-19T14:26:50.047679Z",
     "shell.execute_reply": "2024-12-19T14:26:50.045560Z",
     "shell.execute_reply.started": "2024-12-19T14:26:50.031467Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.2294, -0.2294, -1.1471,  1.6059], grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Manual LayerNorm for a single token (first sample, first position).\n",
    "x = embedding[0, 0, :]\n",
    "# Centre the features of this token.\n",
    "x1 = x - torch.mean(x)\n",
    "# Divide by sqrt(biased variance + eps), matching nn.LayerNorm and the vectorized cell below.\n",
    "# Without eps a token whose features are all equal would give 0/0 = NaN.\n",
    "x2 = x1 / torch.sqrt(torch.var(x1, unbiased=False) + ln.eps)\n",
    "# Element-wise affine transform with the learned scale and shift.\n",
    "x3 = x2 * ln.weight + ln.bias\n",
    "x3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "af2e56ec-9ac1-4365-ac29-92b9cd20e91d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-19T14:29:44.466924Z",
     "iopub.status.busy": "2024-12-19T14:29:44.466266Z",
     "iopub.status.idle": "2024-12-19T14:29:44.475613Z",
     "shell.execute_reply": "2024-12-19T14:29:44.473354Z",
     "shell.execute_reply.started": "2024-12-19T14:29:44.466876Z"
    }
   },
   "outputs": [],
   "source": [
    "# Conventional names for the LayerNorm scale (gamma) and shift (beta).\n",
    "gamma, beta = ln.weight, ln.bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "26fa9951-c58f-42ab-a86c-4ae43f8dea0d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-19T14:30:04.667088Z",
     "iopub.status.busy": "2024-12-19T14:30:04.666424Z",
     "iopub.status.idle": "2024-12-19T14:30:04.683057Z",
     "shell.execute_reply": "2024-12-19T14:30:04.681070Z",
     "shell.execute_reply.started": "2024-12-19T14:30:04.667040Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[-0.2294, -0.2294, -1.1471,  1.6059],\n",
       "         [-0.1525, -0.7625, -0.7625,  1.6775],\n",
       "         [-1.6667,  0.3333,  0.3333,  1.0000]],\n",
       "\n",
       "        [[-1.1832,  1.5213, -0.5071,  0.1690],\n",
       "         [-0.9045, -0.9045,  1.5075,  0.3015],\n",
       "         [ 1.3471, -0.9622,  0.5773, -0.9622]]], grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Vectorized LayerNorm: statistics are taken over the last (feature) dimension for every token at once.\n",
    "token_mean = embedding.mean(dim=-1, keepdim=True)\n",
    "token_var = embedding.var(dim=-1, unbiased=False, keepdim=True)\n",
    "normalized = (embedding - token_mean) / torch.sqrt(token_var + ln.eps)\n",
    "gamma * normalized + beta"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
